from __future__ import annotations

import pytest

pyiceberg = pytest.importorskip("pyiceberg")

import contextlib

import pyarrow as pa
from pyiceberg.avro.file import AvroFile

import daft
from tests.conftest import assert_df_equals


@contextlib.contextmanager
def table_written_by_pyiceberg(local_pyiceberg_catalog):
    """Yield the name of a map-typed Iceberg table populated via pyiceberg's native append.

    Creates ``pyiceberg.map_table`` in the given catalog, appends three rows
    (an int64 column plus a map<int32, string> column), yields the table name,
    and always drops the table on exit.
    """
    schema = pa.schema([("col", pa.int64()), ("mapCol", pa.map_(pa.int32(), pa.string()))])

    data = {"col": [1, 2, 3], "mapCol": [[(1, "foo"), (2, "bar")], [(3, "baz")], [(4, "foobar")]]}
    arrow_table = pa.Table.from_pydict(data, schema=schema)
    table_name = "pyiceberg.map_table"

    # Create the table outside the try block: if creation fails there is nothing
    # to clean up, and a drop_table in finally would only mask the real error.
    table = local_pyiceberg_catalog.create_table(table_name, schema=schema)
    try:
        table.append(arrow_table)
        yield table_name
    finally:
        local_pyiceberg_catalog.drop_table(table_name)


@contextlib.contextmanager
def table_written_by_daft(local_pyiceberg_catalog):
    """Yield the name of a map-typed Iceberg table populated via daft's Iceberg writer.

    Creates ``pyiceberg.map_table`` in the given catalog, writes the same three
    rows as ``table_written_by_pyiceberg`` but through ``df.write_iceberg`` in
    overwrite mode, refreshes the table metadata, yields the table name, and
    always drops the table on exit.
    """
    schema = pa.schema([("col", pa.int64()), ("mapCol", pa.map_(pa.int32(), pa.string()))])

    data = {"col": [1, 2, 3], "mapCol": [[(1, "foo"), (2, "bar")], [(3, "baz")], [(4, "foobar")]]}
    arrow_table = pa.Table.from_pydict(data, schema=schema)
    table_name = "pyiceberg.map_table"

    # Create the table outside the try block: if creation fails there is nothing
    # to clean up, and a drop_table in finally would only mask the real error.
    table = local_pyiceberg_catalog.create_table(table_name, schema=schema)
    try:
        df = daft.from_arrow(arrow_table)
        df.write_iceberg(table, mode="overwrite")
        # Refresh so the yielded table name resolves to the just-written snapshot.
        table.refresh()
        yield table_name
    finally:
        local_pyiceberg_catalog.drop_table(table_name)


@pytest.mark.integration()
def test_pyiceberg_written_catalog(local_iceberg_catalog):
    """Round-trip check: a table appended via pyiceberg reads back identically through daft."""
    catalog_name, pyiceberg_catalog = local_iceberg_catalog
    with table_written_by_pyiceberg(pyiceberg_catalog) as table_name:
        from_daft = daft.read_table(f"{catalog_name}.{table_name}").to_pandas()
        from_iceberg = pyiceberg_catalog.load_table(table_name).scan().to_arrow().to_pandas()
        assert_df_equals(from_daft, from_iceberg, sort_key=[])


@pytest.mark.integration()
# NOTE(review): skipped with no documented reason — TODO add reason=... explaining why.
@pytest.mark.skip
def test_daft_written_catalog(local_iceberg_catalog):
    """Round-trip check: a map-typed table written through daft reads back identically."""
    catalog_name, pyiceberg_catalog = local_iceberg_catalog
    with table_written_by_daft(pyiceberg_catalog) as table_name:
        from_daft = daft.read_table(f"{catalog_name}.{table_name}").to_pandas()
        from_iceberg = pyiceberg_catalog.load_table(table_name).scan().to_arrow().to_pandas()
        assert_df_equals(from_daft, from_iceberg, sort_key=[])


def get_data_files(table):
    """Return the file paths of data files added in the table's current snapshot.

    Refreshes the table, then walks the snapshot's manifest list and each
    manifest file using raw Avro readers, collecting entries whose status is
    1 (ADDED per the Iceberg manifest-entry spec — EXISTING/DELETED entries
    are intentionally excluded).

    Raises:
        ValueError: if the table has no current snapshot.
    """
    table.refresh()

    snapshot = table.current_snapshot()
    if not snapshot:
        raise ValueError("Table has no current snapshot")

    # First pass: collect every manifest path from the snapshot's manifest list.
    with AvroFile(table.io.new_input(snapshot.manifest_list)) as manifest_list:
        manifest_paths = [entry.manifest_path for entry in manifest_list]

    # Second pass: scan each manifest and keep only ADDED (status == 1) entries.
    paths = []
    for manifest_path in manifest_paths:
        with AvroFile(table.io.new_input(manifest_path)) as manifest:
            paths.extend(entry.data_file.file_path for entry in manifest if entry.status == 1)

    return paths


@pytest.mark.integration()
def test_daft_custom_location(local_iceberg_catalog):
    """Verify the ``write.data.path`` property controls where daft writes Iceberg data files.

    Phase 1 writes to a table created without the property and asserts no data
    file lands under the custom suffix; phase 2 recreates the table with
    ``write.data.path`` set and asserts every data file lands under it.
    """
    _, local_pyiceberg_catalog = local_iceberg_catalog
    schema = pa.schema([("data", pa.string())])

    data = {"data": ["foo", "bar", "baz"]}
    arrow_table = pa.Table.from_pydict(data, schema=schema)
    table_name = "pyiceberg.table_custom_location"

    # Phase 1: create the table in the default location.
    # create_table sits outside the try so that a failed create is not masked
    # by drop_table raising in finally (and custom_data_location is always
    # assigned before phase 2 runs).
    table = local_pyiceberg_catalog.create_table(table_name, schema=schema, properties={})
    try:
        df = daft.from_arrow(arrow_table)
        df.write_iceberg(
            table,
            mode="overwrite",
        )
        table.refresh()

        base_table_location = table.metadata.location
        custom_data_location = base_table_location + "/custom-suffix"

        data_files = get_data_files(table)
        assert len(data_files) > 0
        for file_path in data_files:
            assert not file_path.startswith(custom_data_location), "File found in custom location"

    finally:
        local_pyiceberg_catalog.drop_table(table_name)

    # Phase 2: re-create the table with a custom data location.
    table = local_pyiceberg_catalog.create_table(
        table_name, schema=schema, properties={"write.data.path": custom_data_location}
    )
    try:
        df = daft.from_arrow(arrow_table)
        df.write_iceberg(
            table,
            mode="overwrite",
        )
        table.refresh()

        data_files = get_data_files(table)
        assert len(data_files) > 0
        for file_path in data_files:
            assert file_path.startswith(custom_data_location), f"File found outside custom location: {file_path}"
    finally:
        local_pyiceberg_catalog.drop_table(table_name)
